#!/usr/bin/env python
# -*- coding: utf-8 -*-

from hmmlearn.hmm import GaussianHMM
import numpy as np
from matplotlib import cm, pyplot as plt
from numpy import nan_to_num
from pylab import mpl
import math
from typing import List
import pandas as pd

from web.models.commodity_future_date_data import CommodityFutureDateData
from web.manager.log_manager import LogManager
from web.task.base_task import BaseTask
from web.util.math_util import MathUtil

Logger = LogManager.get_logger(__name__)


class HmmTask(BaseTask):
    """
    Quantitative trading experiments based on hidden Markov models (HMM).
    Reference: https://www.jtqh.com/show-33-2779-1.html
    """

    def do_task(self):
        """
        Grade the "PM" closing prices into five states and build Markov transition matrices.

        Steps:
        1. Load all rows with code "PM" ordered by ``transaction_date``.
        2. Ask ``MathUtil.deviate_grade_by_mean_and_standard`` for grade boundaries
           (indices 1..4 of its result are used as cut points; presumably they are
           ascending -- confirm against ``MathUtil``).
        3. Label every closing price "a" (great drop), "b" (fall), "c" (stable),
           "d" (rise) or "e" (sharp rise). Prices matching no branch get no label.
        4. Count transitions between consecutive labels (5x5 frequency matrix).
        5. Normalize each row by its total to obtain the transition probability
           matrix; rows without any observation stay all zero.

        All intermediate results are written with ``print``. Nothing is returned.
        """

        filter_dict = {"code": "PM"}
        order_by_list = ['transaction_date']
        commodity_future_date_data_list = self.commodity_future_date_data_service.commodity_future_date_data_dao.find_list(
            filter_dict, dict(), order_by_list)
        if commodity_future_date_data_list is None or len(commodity_future_date_data_list) == 0:
            return

        # Number of grades requested from MathUtil
        choice = 5
        # One label per grade, ordered from "great drop" to "sharp rise"
        state_labels = ("a", "b", "c", "d", "e")

        close_price_list = [item.close_price for item in commodity_future_date_data_list]
        print(close_price_list)

        # Grade boundaries derived from mean and standard deviation
        grade_list = MathUtil.deviate_grade_by_mean_and_standard(close_price_list, choice)
        print(grade_list)

        # The explicit chain is kept on purpose: the first branch additionally
        # requires a positive price, so it cannot be collapsed into a bisect.
        class_list = list()
        for close_price in close_price_list:
            if 0 < close_price < grade_list[1]:
                class_list.append("a")
            elif grade_list[1] <= close_price < grade_list[2]:
                class_list.append("b")
            elif grade_list[2] <= close_price < grade_list[3]:
                class_list.append("c")
            elif grade_list[3] <= close_price < grade_list[4]:
                class_list.append("d")
            elif close_price >= grade_list[4]:
                class_list.append("e")

        print("******************不同状态的数量*********************")
        print("大跌个数为：", class_list.count("a"))
        print("下跌个数为：", class_list.count("b"))
        print("稳定的个数为：", class_list.count("c"))
        print("上涨的个数为：", class_list.count("d"))
        print("大涨的个数为：", class_list.count("e"))
        print(class_list)

        # Transition frequency matrix: entry [i, j] counts how often state i is
        # immediately followed by state j.
        state_index = {label: position for position, label in enumerate(state_labels)}
        state_count = len(state_labels)
        transfer_frequency_matrix = np.zeros((state_count, state_count), dtype=int)
        for current_state, next_state in zip(class_list, class_list[1:]):
            transfer_frequency_matrix[state_index[current_state], state_index[next_state]] += 1
        print(transfer_frequency_matrix)

        # Transition probability matrix: divide every row by its own total so the
        # row sums to 1. ``where`` skips rows with no observations, which keeps
        # them at 0.0 without producing 0/0 warnings.
        row_totals = transfer_frequency_matrix.sum(axis=1, keepdims=True)
        transition_probability_matrix = np.divide(
            transfer_frequency_matrix,
            row_totals,
            out=np.zeros((state_count, state_count), dtype=float),
            where=row_totals != 0,
        )
        print(transition_probability_matrix)

    def do_task_2(self):
        """
        Fit a Gaussian HMM on "CU" daily data and plot closing prices by hidden state.

        Rows with code "CU" are read in ``transaction_date`` order starting at
        offset 3857; rows whose turnover, close, highest or lowest price is NaN or
        infinite are skipped. Three features are built per day:

        * log(highest price) - log(lowest price)
        * 5-day log return of the closing price
        * 5-day log change of the turnover

        A ``GaussianHMM`` with ``n`` hidden states is fitted with a full covariance
        matrix, falling back to a diagonal one when fitting raises ``ValueError``.
        The closing prices are then drawn as one marker series per hidden state and
        saved to ``hmm.png`` in the current working directory.

        Nothing is returned.
        """

        n = 6  # number of hidden states
        code = "CU"
        lag = 5  # window, in rows, of the log-return / log-turnover features
        # One marker colour per hidden state; reused cyclically if n grows.
        state_colors = ['red', 'black', 'blue', 'green', 'purple', 'brown']

        commodity_future_date_data_queryset = CommodityFutureDateData.objects.filter(code=code).order_by(
            'transaction_date')[3857:]
        if commodity_future_date_data_queryset is None or len(commodity_future_date_data_queryset) == 0:
            return

        turnover_list = list()
        close_price_list = list()
        highest_price_list = list()
        lowest_price_list = list()
        transaction_date_list = list()
        for commodity_future_date_data in commodity_future_date_data_queryset:
            raw_values = (commodity_future_date_data.turnover,
                          commodity_future_date_data.close_price,
                          commodity_future_date_data.highest_price,
                          commodity_future_date_data.lowest_price)
            # Skip the whole row as soon as one field is NaN or +/-inf.
            if not all(math.isfinite(raw_value) for raw_value in raw_values):
                continue
            turnover_list.append(float(commodity_future_date_data.turnover))
            close_price_list.append(float(commodity_future_date_data.close_price))
            highest_price_list.append(float(commodity_future_date_data.highest_price))
            lowest_price_list.append(float(commodity_future_date_data.lowest_price))
            transaction_date_list.append(commodity_future_date_data.transaction_date)

        if len(close_price_list) <= lag:
            # Every feature below needs at least lag + 1 rows; otherwise the
            # slices are empty and the model cannot be fitted.
            Logger.error('not enough valid rows to build HMM features')
            return

        print(np.diff(np.log(np.array(close_price_list))))
        print(close_price_list[lag:])
        print(close_price_list[:-lag])
        print(turnover_list[lag:])
        print(turnover_list[:-lag])

        close_price_ndarray = np.array(close_price_list)
        turnover_ndarray = np.array(turnover_list)
        log_del = (np.log(np.array(highest_price_list)) - np.log(np.array(lowest_price_list)))[lag:]
        log_return_5 = np.log(close_price_ndarray[lag:]) - np.log(close_price_ndarray[:-lag])
        log_turnover_5 = np.log(turnover_ndarray[lag:]) - np.log(turnover_ndarray[:-lag])
        # Drop the first `lag` rows so prices and dates line up with the features.
        close_price_list = close_price_list[lag:]
        Date = pd.to_datetime(transaction_date_list[lag:])
        A = np.column_stack([log_del, log_return_5, log_turnover_5])
        print(A)

        print(np.isnan(A).any())
        print(np.isinf(A).any())
        # A zero turnover yields log(0) = -inf and possibly inf - inf = NaN;
        # replace those so the model receives finite input only.
        A = np.nan_to_num(A, posinf=1e10, neginf=-1e10)

        try:
            model = GaussianHMM(n_components=n, covariance_type="full", n_iter=5000).fit(A)
        except ValueError:
            Logger.error('非正定矩阵，covariance_type="diag"')
            model = GaussianHMM(n_components=n, covariance_type="diag", n_iter=5000).fit(A)
        hidden_state_ndarray = model.predict(A)

        # Group (date, closing price) pairs by predicted hidden state.
        points_by_state = {state: (list(), list()) for state in range(n)}
        for transaction_date, close_price, hidden_state in zip(Date, close_price_list, hidden_state_ndarray):
            state_date_list, state_close_price_list = points_by_state[int(hidden_state)]
            state_date_list.append(transaction_date)
            state_close_price_list.append(close_price)

        # Turn interactive mode off; the chart is only written to disk.
        plt.ioff()

        mpl.rcParams["font.sans-serif"] = ["SimHei"]
        mpl.rcParams['axes.unicode_minus'] = False

        fig, ax = plt.subplots(figsize=(22, 10))
        try:
            for state in range(n):
                state_date_list, state_close_price_list = points_by_state[state]
                # linewidth 0 leaves only the markers visible
                ax.plot(state_date_list, state_close_price_list, label='状态%d' % state, linewidth=0.0,
                        linestyle='--', marker='o', color=state_colors[state % len(state_colors)],
                        markersize=7)
            ax.legend()
            ax.set_title('隐马尔可夫')
            ax.set_xlabel('时间')
            ax.set_ylabel('收盘价')
            fig.savefig("hmm.png")
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

    def fill_ndarray(self, t1):
        for i in range(t1.shape[1]):  # 遍历每一列（每一列中的nan替换成该列的均值）
            temp_col = t1[:, i]  # 当前的一列
            nan_num = np.count_nonzero(temp_col != temp_col)
            if nan_num != 0:  # 不为0，说明当前这一列中有nan
                temp_not_nan_col = temp_col[temp_col == temp_col]  # 去掉nan的ndarray

                # 选中当前为nan的位置，把值赋值为不为nan的均值
                temp_col[np.isnan(temp_col)] = temp_not_nan_col.mean()  # mean()表示求均值。
                temp_col[np.isinf(temp_col)] = temp_not_nan_col.mean()
        return t1

    def predict(self):
        """
        Predict future closing prices of "CU" with a Gaussian HMM and plot them.

        The observations are, per day, the closing price change versus the
        previous row and the turnover. Price changes further than 8 standard
        deviations from their mean are replaced by that mean. The last
        ``test_data_number`` observations form the test set, everything before
        them the training set.

        For every test observation the hidden state is predicted and the expected
        price change after that state (transition matrix times state means) is
        added to the running predicted price, which starts at the closing price
        of the first test day. Actual and predicted prices are shown with
        ``plt.show()``.

        Nothing is returned. When there are not more than ``test_data_number``
        observations the method logs an error and returns without fitting.
        """

        n = 6  # number of hidden states
        code = "CU"
        test_data_number: int = 100

        commodity_future_date_data_queryset = CommodityFutureDateData.objects.values(
            'close_price', 'turnover', 'transaction_date').filter(code=code).order_by('transaction_date')[:]
        if commodity_future_date_data_queryset is None or len(commodity_future_date_data_queryset) == 0:
            return

        df = pd.DataFrame(list(commodity_future_date_data_queryset))
        print("原始数据的大小：", df.shape)
        print("原始数据的列名", df.columns)

        df['transaction_date'] = pd.to_datetime(df['transaction_date'])
        df.reset_index(inplace=True, drop=True)
        print(df.head())

        # Row k of the observations describes the move from row k to row k + 1,
        # hence dates and turnover skip the first row.
        dates = df['transaction_date'][1:]
        close_v = df['close_price']
        volume = df['turnover'][1:].astype('float')
        diff = np.diff(close_v).astype('float')
        X = np.column_stack([diff, volume])
        print("输入数据的大小：", X.shape)

        if len(X) <= test_data_number:
            # The training slice would be empty and fit() would fail.
            Logger.error('not enough rows to split into training and test data')
            return

        # Outliers in the price change are set to the mean of the price change.
        price_change = X[:, 0]
        price_change_mean = float(price_change.mean(dtype=float))
        price_change_spread = float(8 * price_change.std(dtype=float))
        lower_bound = price_change_mean - price_change_spread
        upper_bound = price_change_mean + price_change_spread
        outlier_mask = (price_change < lower_bound) | (price_change > upper_bound)
        X[outlier_mask, 0] = price_change_mean

        # Split into training and test data
        X_Train = X[:-test_data_number]
        X_Test = X[-test_data_number:]
        print("训练集的大小：", X_Train.shape)
        print("测试集的大小：", X_Test.shape)

        model = GaussianHMM(n_components=n, covariance_type='diag', n_iter=1000, min_covar=0.1)
        model.fit(X_Train)

        # Expected next-step price change for each current hidden state
        expected_returns = np.dot(model.transmat_, model.means_)[:, 0]
        predicted_price = []
        current_price = float(close_v.iloc[-test_data_number])
        for observation in X_Test:
            hidden_states = model.predict(observation.reshape(1, -1))
            print("hidden_states：", hidden_states)
            # predict() returns a length-1 array; index it explicitly instead of
            # relying on the deprecated array-to-scalar conversion.
            current_price = current_price + float(expected_returns[hidden_states[0]])
            predicted_price.append(current_price)

        # The i-th prediction targets the day after the i-th test observation, so
        # the last prediction has no actual price to compare against.
        x = dates[-(test_data_number - 1):]
        y_act = close_v[-(test_data_number - 1):].astype('float')
        y_pre = predicted_price[:-1]
        fig = plt.figure(figsize=(8, 6))
        plt.plot(x, y_act, linestyle="-", marker="o", color='g')
        plt.plot(x, y_pre, linestyle="-", marker="*", color='r')
        plt.legend(['Actual', 'Predicted'])
        plt.show()
        plt.close(fig)

    def training_all_commodity_future_with_few_eigenvalue(self):
        """
        Run the HMM workflow for all commodity futures (main continuous contracts) with few features.

        According to the original description the workflow generates histograms,
        learns on the training data set and tests on the test data set. This task
        only triggers ``model_hmm_service``; it returns nothing.
        """

        hmm_service = self.model_hmm_service
        # Whatever the service returns is intentionally not passed on.
        hmm_service.training_all_commodity_future_with_few_eigenvalue()

    def training_all_commodity_future_with_many_eigenvalue(self):
        """
        Run the HMM workflow for all commodity futures (main continuous contracts) with as many features as possible.

        According to the original description the workflow generates histograms,
        learns on the training data set and tests on the test data set. This task
        only triggers ``model_hmm_service``; it returns nothing.
        """

        hmm_service = self.model_hmm_service
        hmm_service.training_all_commodity_future_with_many_eigenvalue()

    def training_all_etf_transaction_data_with_many_eigenvalue(self):
        """
        Run the HMM workflow for all ETF data with as many features as possible.

        According to the original description the workflow generates histograms,
        learns on the training data set and tests on the test data set. This task
        only triggers ``model_hmm_service``; it returns nothing.
        """

        hmm_service = self.model_hmm_service
        hmm_service.training_all_etf_transaction_data_with_many_eigenvalue()

    def create_hmm_profit_loss_and_close_price_line_picture(self):
        """
        Create the line chart of HMM profit/loss and closing price.

        The chart is produced by ``model_hmm_service``; this task only triggers it
        and returns nothing.
        """

        hmm_service = self.model_hmm_service
        hmm_service.create_hmm_profit_loss_and_close_price_line_picture()

    def analysis_hmm_result(self):
        """
        Analyse the results of the HMM algorithm.

        The analysis is performed by ``model_hmm_service``; this task only triggers
        it and returns nothing.
        """

        hmm_service = self.model_hmm_service
        hmm_service.analysis_hmm_result()
